[AArch64][GlobalISel] Improve lowering of vector fp16 fptrunc #163398
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Ryan Cowan (HolyMolyCowMan)

Changes:

This commit improves the lowering of vectors of fp16 when truncating and extending. Truncating has to be handled in a specific way to avoid double rounding.

Patch is 71.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163398.diff

15 Files Affected:
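To illustrate the double-rounding problem described above, here is a minimal standalone C++ sketch. It is not part of the patch and assumes a compiler with the _Float16 extension (for example a recent Clang targeting AArch64): narrowing f64 to f16 via f32 with round-to-nearest at both steps can give a different result than a single f64 to f16 conversion.

#include <cstdio>

int main() {
  // 2049 is exactly halfway between the adjacent f16 values 2048 and 2050.
  // d sits just above that tie point, by less than half an f32 ulp.
  double d = 2049.0 + 0x1p-14;

  _Float16 direct = (_Float16)d;         // rounds up to 2050, the nearest f16 to d
  _Float16 twoStep = (_Float16)(float)d; // the f32 step rounds d down to exactly 2049,
                                         // then the f16 tie breaks to even: 2048

  std::printf("direct=%g twoStep=%g\n", (double)direct, (double)twoStep);
  return 0;
}

The patch avoids this for vector f64 to f16 truncation by doing the first narrowing step with a round-to-odd conversion (FCVTXN), after which the final round-to-nearest step cannot introduce a second rounding error.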
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index ecaeff77fcb4b..0c71844e3a73e 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -333,6 +333,13 @@ def combine_mul_cmlt : GICombineRule<
(apply [{ applyCombineMulCMLT(*${root}, MRI, B, ${matchinfo}); }])
>;
+def lower_fptrunc_fptrunc: GICombineRule<
+ (defs root:$root),
+ (match (wip_match_opcode G_FPTRUNC):$root,
+ [{ return matchFpTruncFpTrunc(*${root}, MRI); }]),
+ (apply [{ applyFpTruncFpTrunc(*${root}, MRI, B); }])
+>;
+
// Post-legalization combines which should happen at all optimization levels.
// (E.g. ones that facilitate matching for the selector) For example, matching
// pseudos.
@@ -341,7 +348,7 @@ def AArch64PostLegalizerLowering
[shuffle_vector_lowering, vashr_vlshr_imm,
icmp_lowering, build_vector_lowering,
lower_vector_fcmp, form_truncstore, fconstant_to_constant,
- vector_sext_inreg_to_shift,
+ vector_sext_inreg_to_shift, lower_fptrunc_fptrunc,
unmerge_ext_to_unmerge, lower_mulv2s64,
vector_unmerge_lowering, insertelt_nonconst,
unmerge_duplanes]> {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 9e2d698e04ae7..fde86449a76a7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -21,6 +21,7 @@
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/DerivedTypes.h"
@@ -817,14 +818,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.legalFor(
{{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}})
.libcallFor({{s16, s128}, {s32, s128}, {s64, s128}})
- .clampNumElements(0, v4s16, v4s16)
- .clampNumElements(0, v2s32, v2s32)
+ .moreElementsToNextPow2(1)
+ .customIf([](const LegalityQuery &Q) {
+ LLT DstTy = Q.Types[0];
+ LLT SrcTy = Q.Types[1];
+ return SrcTy.isFixedVector() && DstTy.isFixedVector() &&
+ SrcTy.getScalarSizeInBits() == 64 &&
+ DstTy.getScalarSizeInBits() == 16;
+ })
+ // Clamp based on input
+ .clampNumElements(1, v4s32, v4s32)
+ .clampNumElements(1, v2s64, v2s64)
.scalarize(0);
getActionDefinitionsBuilder(G_FPEXT)
.legalFor(
{{s32, s16}, {s64, s16}, {s64, s32}, {v4s32, v4s16}, {v2s64, v2s32}})
.libcallFor({{s128, s64}, {s128, s32}, {s128, s16}})
+ .moreElementsToNextPow2(0)
+ .customIf([](const LegalityQuery &Q) {
+ LLT DstTy = Q.Types[0];
+ LLT SrcTy = Q.Types[1];
+ return SrcTy.isVector() && DstTy.isVector() &&
+ SrcTy.getScalarSizeInBits() == 16 &&
+ DstTy.getScalarSizeInBits() == 64;
+ })
.clampNumElements(0, v4s32, v4s32)
.clampNumElements(0, v2s64, v2s64)
.scalarize(0);
@@ -1472,6 +1490,12 @@ bool AArch64LegalizerInfo::legalizeCustom(
return legalizeICMP(MI, MRI, MIRBuilder);
case TargetOpcode::G_BITCAST:
return legalizeBitcast(MI, Helper);
+ case TargetOpcode::G_FPEXT:
+ // In order to vectorise f16 to f64 properly, we need to use f32 as an
+ // intermediary
+ return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPEXT);
+ case TargetOpcode::G_FPTRUNC:
+ return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPTRUNC);
}
llvm_unreachable("expected switch to return");
@@ -2396,3 +2420,37 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI,
MI.eraseFromParent();
return true;
}
+
+bool AArch64LegalizerInfo::legalizeViaF32(MachineInstr &MI,
+ MachineIRBuilder &MIRBuilder,
+ MachineRegisterInfo &MRI,
+ unsigned Opcode) const {
+ Register Dst = MI.getOperand(0).getReg();
+ Register Src = MI.getOperand(1).getReg();
+ LLT DstTy = MRI.getType(Dst);
+ LLT SrcTy = MRI.getType(Src);
+
+ LLT MidTy = LLT::fixed_vector(SrcTy.getNumElements(), LLT::scalar(32));
+
+ MachineInstrBuilder Mid;
+ MachineInstrBuilder Fin;
+ MIRBuilder.setInstrAndDebugLoc(MI);
+ switch (Opcode) {
+ default:
+ return false;
+ case TargetOpcode::G_FPEXT: {
+ Mid = MIRBuilder.buildFPExt(MidTy, Src);
+ Fin = MIRBuilder.buildFPExt(DstTy, Mid.getReg(0));
+ break;
+ }
+ case TargetOpcode::G_FPTRUNC: {
+ Mid = MIRBuilder.buildFPTrunc(MidTy, Src);
+ Fin = MIRBuilder.buildFPTrunc(DstTy, Mid.getReg(0));
+ break;
+ }
+ }
+
+ MRI.replaceRegWith(Dst, Fin.getReg(0));
+ MI.eraseFromParent();
+ return true;
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
index bcb294326fa92..049808d66f983 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
@@ -67,6 +67,8 @@ class AArch64LegalizerInfo : public LegalizerInfo {
bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const;
+ bool legalizeViaF32(MachineInstr &MI, MachineIRBuilder &MIRBuilder,
+ MachineRegisterInfo &MRI, unsigned Opcode) const;
const AArch64Subtarget *ST;
};
} // End llvm namespace.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index 23dcaea2ac1a4..30417148a5a00 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -901,6 +901,197 @@ unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
return 0;
}
+// Helper function for matchFpTruncFpTrunc.
+// Checks that the given definition belongs to an FPTRUNC and that the source is
+// not an integer, as no rounding is necessary due to the range of values
+bool checkTruncSrc(MachineRegisterInfo &MRI, MachineInstr *MaybeFpTrunc) {
+ if (!MaybeFpTrunc || MaybeFpTrunc->getOpcode() != TargetOpcode::G_FPTRUNC)
+ return false;
+
+ // Check the source is 64 bits as we only want to match a very specific
+ // pattern
+ Register FpTruncSrc = MaybeFpTrunc->getOperand(1).getReg();
+ LLT SrcTy = MRI.getType(FpTruncSrc);
+ if (SrcTy.getScalarSizeInBits() != 64)
+ return false;
+
+  // Need to check the float didn't come from an int as no rounding is
+  // necessary
+ MachineInstr *FpTruncSrcDef = getDefIgnoringCopies(FpTruncSrc, MRI);
+ if (FpTruncSrcDef->getOpcode() == TargetOpcode::G_SITOFP ||
+ FpTruncSrcDef->getOpcode() == TargetOpcode::G_UITOFP)
+ return false;
+
+ return true;
+}
+
+// To avoid double rounding issues we need to lower FPTRUNC(FPTRUNC) to an odd
+// rounding truncate and a normal truncate. When
+// truncating an FP that came from an integer this is not a problem as the range
+// of values is lower in the int
+bool matchFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI) {
+ if (MI.getOpcode() != TargetOpcode::G_FPTRUNC)
+ return false;
+
+ // Check the destination is 16 bits as we only want to match a very specific
+ // pattern
+ Register Dst = MI.getOperand(0).getReg();
+ LLT DstTy = MRI.getType(Dst);
+ if (DstTy.getScalarSizeInBits() != 16)
+ return false;
+
+ Register Src = MI.getOperand(1).getReg();
+
+ MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+ if (!ParentDef)
+ return false;
+
+ MachineInstr *FpTruncDef;
+ switch (ParentDef->getOpcode()) {
+ default:
+ return false;
+ case TargetOpcode::G_CONCAT_VECTORS: {
+ // Expecting exactly two FPTRUNCs
+ if (ParentDef->getNumOperands() != 3)
+ return false;
+
+ // All operands need to be FPTRUNC
+ for (unsigned OpIdx = 1, NumOperands = ParentDef->getNumOperands();
+ OpIdx != NumOperands; ++OpIdx) {
+ Register FpTruncDst = ParentDef->getOperand(OpIdx).getReg();
+
+ FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+ if (!checkTruncSrc(MRI, FpTruncDef))
+ return false;
+ }
+
+ return true;
+ }
+ // This is to match cases in which vectors are widened to a larger size
+ case TargetOpcode::G_INSERT_VECTOR_ELT: {
+ Register VecExtractDst = ParentDef->getOperand(2).getReg();
+ MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+ Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+ FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+ if (!checkTruncSrc(MRI, FpTruncDef))
+ return false;
+ break;
+ }
+ case TargetOpcode::G_FPTRUNC: {
+ Register FpTruncDst = ParentDef->getOperand(1).getReg();
+ FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+ if (!checkTruncSrc(MRI, FpTruncDef))
+ return false;
+ break;
+ }
+ }
+
+ return true;
+}
+
+void applyFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &B) {
+ Register Dst = MI.getOperand(0).getReg();
+ Register Src = MI.getOperand(1).getReg();
+
+ LLT V2F32 = LLT::fixed_vector(2, LLT::scalar(32));
+ LLT V4F32 = LLT::fixed_vector(4, LLT::scalar(32));
+ LLT V4F16 = LLT::fixed_vector(4, LLT::scalar(16));
+
+ B.setInstrAndDebugLoc(MI);
+
+ MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+ if (!ParentDef)
+ return;
+
+ switch (ParentDef->getOpcode()) {
+ default:
+ return;
+ case TargetOpcode::G_INSERT_VECTOR_ELT: {
+ Register VecExtractDst = ParentDef->getOperand(2).getReg();
+ MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+ Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+ MachineInstr *FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+ Register FpTruncSrc = FpTruncDef->getOperand(1).getReg();
+ MRI.setRegClass(FpTruncSrc, &AArch64::FPR128RegClass);
+
+ Register Fp32 = MRI.createGenericVirtualRegister(V2F32);
+ MRI.setRegClass(Fp32, &AArch64::FPR64RegClass);
+
+ B.buildInstr(AArch64::FCVTXNv2f32, {Fp32}, {FpTruncSrc});
+
+ // Only 4f32 -> 4f16 is legal so we need to mimic that situation
+ Register Fp32Padding = B.buildUndef(V2F32).getReg(0);
+ MRI.setRegClass(Fp32Padding, &AArch64::FPR64RegClass);
+
+ Register Fp32Full = MRI.createGenericVirtualRegister(V4F32);
+ MRI.setRegClass(Fp32Full, &AArch64::FPR128RegClass);
+ B.buildConcatVectors(Fp32Full, {Fp32, Fp32Padding});
+
+ Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+ MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+ B.buildFPTrunc(Fp16, Fp32Full);
+
+ MRI.replaceRegWith(Dst, Fp16);
+ MI.eraseFromParent();
+ break;
+ }
+ case TargetOpcode::G_CONCAT_VECTORS: {
+ // Get the two FP Truncs that are being concatenated
+ Register FpTrunc1Dst = ParentDef->getOperand(1).getReg();
+ Register FpTrunc2Dst = ParentDef->getOperand(2).getReg();
+
+ MachineInstr *FpTrunc1Def = getDefIgnoringCopies(FpTrunc1Dst, MRI);
+ MachineInstr *FpTrunc2Def = getDefIgnoringCopies(FpTrunc2Dst, MRI);
+
+ // Make the registers 128bit to store the 2 doubles
+ Register LoFp64 = FpTrunc1Def->getOperand(1).getReg();
+ MRI.setRegClass(LoFp64, &AArch64::FPR128RegClass);
+ Register HiFp64 = FpTrunc2Def->getOperand(1).getReg();
+ MRI.setRegClass(HiFp64, &AArch64::FPR128RegClass);
+
+ B.setInstrAndDebugLoc(MI);
+
+ // Convert the lower half
+ Register LoFp32 = MRI.createGenericVirtualRegister(V2F32);
+ MRI.setRegClass(LoFp32, &AArch64::FPR64RegClass);
+ B.buildInstr(AArch64::FCVTXNv2f32, {LoFp32}, {LoFp64});
+
+ // Create a register for the high half to use
+ Register AccUndef = MRI.createGenericVirtualRegister(V4F32);
+ MRI.setRegClass(AccUndef, &AArch64::FPR128RegClass);
+ B.buildUndef(AccUndef);
+
+ Register Acc = MRI.createGenericVirtualRegister(V4F32);
+ MRI.setRegClass(Acc, &AArch64::FPR128RegClass);
+ B.buildInstr(TargetOpcode::INSERT_SUBREG)
+ .addDef(Acc)
+ .addUse(AccUndef)
+ .addUse(LoFp32)
+ .addImm(AArch64::dsub);
+
+ // Convert the high half
+ Register AccOut = MRI.createGenericVirtualRegister(V4F32);
+ MRI.setRegClass(AccOut, &AArch64::FPR128RegClass);
+ B.buildInstr(AArch64::FCVTXNv4f32).addDef(AccOut).addUse(Acc).addUse(HiFp64);
+
+ Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+ MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+ B.buildFPTrunc(Fp16, AccOut);
+
+ MRI.replaceRegWith(Dst, Fp16);
+ MI.eraseFromParent();
+ break;
+ }
+ }
+}
+
/// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
/// instruction \p MI.
bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index 896603d6eb20d..0561f91b6e015 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -555,11 +555,11 @@
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_FPEXT (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_FPTRUNC (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_FPTOSI (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
diff --git a/llvm/test/CodeGen/AArch64/arm64-fp128.ll b/llvm/test/CodeGen/AArch64/arm64-fp128.ll
index 3e4b887fed55d..b8b8d20b9a17b 100644
--- a/llvm/test/CodeGen/AArch64/arm64-fp128.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fp128.ll
@@ -1197,30 +1197,22 @@ define <2 x half> @vec_round_f16(<2 x fp128> %val) {
;
; CHECK-GI-LABEL: vec_round_f16:
; CHECK-GI: // %bb.0:
-; CHECK-GI-NEXT: sub sp, sp, #64
-; CHECK-GI-NEXT: str x30, [sp, #48] // 8-byte Folded Spill
-; CHECK-GI-NEXT: .cfi_def_cfa_offset 64
+; CHECK-GI-NEXT: sub sp, sp, #48
+; CHECK-GI-NEXT: str x30, [sp, #32] // 8-byte Folded Spill
+; CHECK-GI-NEXT: .cfi_def_cfa_offset 48
; CHECK-GI-NEXT: .cfi_offset w30, -16
-; CHECK-GI-NEXT: mov v2.d[0], x8
; CHECK-GI-NEXT: str q1, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT: mov v2.d[1], x8
-; CHECK-GI-NEXT: str q2, [sp, #32] // 16-byte Folded Spill
; CHECK-GI-NEXT: bl __trunctfhf2
; CHECK-GI-NEXT: // kill: def $h0 killed $h0 def $q0
; CHECK-GI-NEXT: str q0, [sp, #16] // 16-byte Folded Spill
; CHECK-GI-NEXT: ldr q0, [sp] // 16-byte Folded Reload
; CHECK-GI-NEXT: bl __trunctfhf2
+; CHECK-GI-NEXT: ldr q1, [sp, #16] // 16-byte Folded Reload
; CHECK-GI-NEXT: // kill: def $h0 killed $h0 def $q0
-; CHECK-GI-NEXT: str q0, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT: ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT: bl __trunctfhf2
-; CHECK-GI-NEXT: ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT: bl __trunctfhf2
-; CHECK-GI-NEXT: ldp q1, q0, [sp] // 32-byte Folded Reload
-; CHECK-GI-NEXT: ldr x30, [sp, #48] // 8-byte Folded Reload
-; CHECK-GI-NEXT: mov v0.h[1], v1.h[0]
-; CHECK-GI-NEXT: // kill: def $d0 killed $d0 killed $q0
-; CHECK-GI-NEXT: add sp, sp, #64
+; CHECK-GI-NEXT: ldr x30, [sp, #32] // 8-byte Folded Reload
+; CHECK-GI-NEXT: mov v1.h[1], v0.h[0]
+; CHECK-GI-NEXT: fmov d0, d1
+; CHECK-GI-NEXT: add sp, sp, #48
; CHECK-GI-NEXT: ret
%dst = fptrunc <2 x fp128> %val to <2 x half>
ret <2 x half> %dst
diff --git a/llvm/test/CodeGen/AArch64/fmla.ll b/llvm/test/CodeGen/AArch64/fmla.ll
index a37aabb0b5384..12b6562b5cf0c 100644
--- a/llvm/test/CodeGen/AArch64/fmla.ll
+++ b/llvm/test/CodeGen/AArch64/fmla.ll
@@ -865,22 +865,22 @@ define <7 x half> @fmuladd_v7f16(<7 x half> %a, <7 x half> %b, <7 x half> %c) {
; CHECK-GI-NOFP16-NEXT: fcvtl v0.4s, v3.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v1.4s, v2.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v3.4s, v5.4h
-; CHECK-GI-NOFP16-NEXT: mov v5.h[0], v2.h[4]
; CHECK-GI-NOFP16-NEXT: fcvtl v4.4s, v4.4h
; CHECK-GI-NOFP16-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v5.h[1], v2.h[5]
-; CHECK-GI-NOFP16-NEXT: fmul v1.4s, v3.4s, v4.4s
-; CHECK-GI-NOFP16-NEXT: fcvtn v3.4h, v0.4s
-; CHECK-GI-NOFP16-NEXT: mov v5.h[2], v2.h[6]
-; CHECK-GI-NOFP16-NEXT: fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[0], v3.h[0]
-; CHECK-GI-NOFP16-NEXT: fcvtl v2.4s, v5.4h
+; CHECK-GI-NOFP16-NEXT: mov v1.h[0], v2.h[4]
+; CHECK-GI-NOFP16-NEXT: fmul v3.4s, v3.4s, v4.4s
+; CHECK-GI-NOFP16-NEXT: mov v1.h[1], v2.h[5]
+; CHECK-GI-NOFP16-NEXT: fcvtn v4.4h, v0.4s
+; CHECK-GI-NOFP16-NEXT: fcvtn v3.4h, v3.4s
+; CHECK-GI-NOFP16-NEXT: mov v1.h[2], v2.h[6]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[0], v4.h[0]
+; CHECK-GI-NOFP16-NEXT: fcvtl v2.4s, v3.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v1.4s, v1.4h
-; CHECK-GI-NOFP16-NEXT: mov v0.h[1], v3.h[1]
-; CHECK-GI-NOFP16-NEXT: fadd v1.4s, v1.4s, v2.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[2], v3.h[2]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[1], v4.h[1]
+; CHECK-GI-NOFP16-NEXT: fadd v1.4s, v2.4s, v1.4s
+; CHECK-GI-NOFP16-NEXT: mov v0.h[2], v4.h[2]
; CHECK-GI-NOFP16-NEXT: fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[3], v3.h[3]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[3], v4.h[3]
; CHECK-GI-NOFP16-NEXT: mov v0.h[4], v1.h[0]
; CHECK-GI-NOFP16-NEXT: mov v0.h[5], v1.h[1]
; CHECK-GI-NOFP16-NEXT: mov v0.h[6], v1.h[2]
@@ -1350,22 +1350,22 @@ define <7 x half> @fmul_v7f16(<7 x half> %a, <7 x half> %b, <7 x half> %c) {
; CHECK-GI-NOFP16-NEXT: fcvtl v0.4s, v3.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v1.4s, v2.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v3.4s, v5.4h
-; CHECK-GI-NOFP16-NEXT: mov v5.h[0], v2.h[4]
; CHECK-GI-NOFP16-NEXT: fcvtl v4.4s, v4.4h
; CHECK-GI-NOFP16-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v5.h[1], v2.h[5]
-; CHECK-GI-NOFP16-NEXT: fmul v1.4s, v3.4s, v4.4s
-; CHECK-GI-NOFP16-NEXT: fcvtn v3.4h, v0.4s
-; CHECK-GI-NOFP16-NEXT: mov v5.h[2], v2.h[6]
-; CHECK-GI-NOFP16-NEXT: fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[0], v3.h[0]
-; CHECK-GI-NOFP16-NEXT: fcvtl v2.4s, v5.4h
+; CHECK-GI-NOFP16-NEXT: mov v1.h[0], v2.h[4]
+; CHECK-GI-NOFP16-NEXT: fmul v3.4s, v3.4s, v4.4s
+; CHECK-GI-NOFP16-NEXT: mov v1.h[1], v2.h[5]
+; CHECK-GI-NOFP16-NEXT: fcvtn v4.4h, v0.4s
+; CHECK-GI-NOFP16-NEXT: fcvtn v3.4h, v3.4s
+; CHECK-GI-NOFP16-NEXT: mov v1.h[2], v2.h[6]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[0], v4.h[0]
+; CHECK-GI-NOFP16-NEXT: fcvtl v2.4s, v3.4h
; CHECK-GI-NOFP16-NEXT: fcvtl v1.4s, v1.4h
-; CHECK-GI-NOFP16-NEXT: mov v0.h[1], v3.h[1]
-; CHECK-GI-NOFP16-NEXT: fadd v1.4s, v1.4s, v2.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[2], v3.h[2]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[1], v4.h[1]
+; CHECK-GI-NOFP16-NEXT: fadd v1.4s, v2.4s, v1.4s
+; CHECK-GI-NOFP16-NEXT: mov v0.h[2], v4.h[2]
; CHECK-GI-NOFP16-NEXT: fcvtn v1.4h, v1.4s
-; CHECK-GI-NOFP16-NEXT: mov v0.h[3], v3.h[3]
+; CHECK-GI-NOFP16-NEXT: mov v0.h[3], v4.h[3]
;...
[truncated]
@llvm/pr-subscribers-llvm-globalisel

Author: Ryan Cowan (HolyMolyCowMan)

Changes:

This commit improves the lowering of vectors of fp16 when truncating and extending. Truncating has to be handled in a specific way to avoid double rounding.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 7ad3118 to e77ef45 (Compare)
I'm not 100% sure that the pass I have included this optimisation in is the correct one; any thoughts on this are more than welcome.
c-rhodes left a comment:
Thanks for the patch Ryan, I've not gone through it entirely yet but I've left some initial comments.
This looks like it should be two different patches, one for fptrunc and one for fpext. The fpext is the easier of the two IIRC, which should hopefully be a "lower" like we do for other extends. Trunc is more difficult but I don't think we can split it into two different trunc nodes and "fix" them later - we might need to introduce a new node type for the round to odd.
Force-pushed from 00a2a15 to 0efa366 (Compare)
Force-pushed from 0efa366 to 411afc0 (Compare)
davemgreen left a comment:
Thanks - Can we add something to explain the new opcode in llvm/docs/GlobalISel/GenericOpcode.rst
(Review thread on the GenericOpcode.rst hunk adding G_FPTRUNC_ODD, next to "Convert a floating point value to a narrower type.")

You should just use G_INTRINSIC_FPTRUNC_ROUND instead of introducing this
Hmm. That is a new one to me and looks like an AMD-ism. We don't really have the instructions to easily support that for any other rounding modes, and it doesn't even support odd rounding modes yet. Considering we don't have a great way to conditionally legalize intrinsics like that, a separate instruction sounds like a better approach for us. (We can always change that in the future if needed).
It is not an AMDism, and you can set the legalize rule based on the immediate
I am vetoing the new opcode. It is identical to FPTRUNC_ROUND with whatever immediate value corresponds to round odd
> It is not an AMDism, and you can set the legalize rule based on the immediate
By AMDism I just meant that it is only handled on AMD at the moment, and there isn't an instruction or very good lowering for it on AArch64.
Round to Odd isn't a standard ieee rounding mode, so is not something that llvm.fptrunc.round supports. It is similar to "to nearest, ties to even", but the ties go to odd, which apparently allows rounding f64->f16 via f32 without introducing double rounding errors. Adding support to FPTRUNC_ROUND just for this one case seems messy for the AArch64 backend, and is not the right direction to take.
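For reference, the effect of a round-to-odd f64 to f32 narrowing (what FCVTX/FCVTXN provide) can be emulated in plain C++ as below. This is an illustrative sketch rather than code from the patch, and fptruncOddRef is a made-up name: when the conversion is inexact, the result is forced to the neighbouring f32 value whose mantissa LSB is 1.

#include <cmath>
#include <cstdint>
#include <cstring>

// Emulate a round-to-odd f64 -> f32 narrowing: exact conversions are returned
// unchanged; inexact conversions return the adjacent f32 whose mantissa LSB is 1.
static float fptruncOddRef(double d) {
  float f = static_cast<float>(d);    // default round-to-nearest-even
  double fd = static_cast<double>(f);
  if (fd == d || std::isnan(d))
    return f;                         // exact (or NaN): nothing to adjust
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  if ((bits & 1u) == 0)               // nearest-even picked the even neighbour,
    f = std::nextafterf(f, fd < d ? HUGE_VALF : -HUGE_VALF); // so take the odd one
  return f;
}

Because f32 keeps many more significand bits than f16, rounding such a value to f16 with the default mode then matches a single correctly rounded f64 to f16 conversion, which is the property the FCVTXN-based lowering here relies on.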
As per the most recent message here, I have removed the new opcode & instead rely on an AArch64 specific opcode. I have thus moved the lowering to a custom legalizer function.
davemgreen left a comment:
I have tried to give the round-to-odd codegen a test to make sure it was really equivalent. I wasn't able to prove anything as the input type is too large, but I ran tests for quite a while without finding any issues.
Can you rebase this?
arsenm left a comment:
I am vetoing the new opcode. It is identical to FPTRUNC_ROUND with whatever immediate value corresponds to round odd
Force-pushed from f004ac1 to 3513809 (Compare)
(Review thread on the new assertion: assert(SrcTy.isFixedVector() && isPowerOf2_32(SrcTy.getNumElements()) && "Expected a power of 2 elements");)
Could we make this work with a "multiple of 2", not a "power of 2"?
I can do but currently the legalizer widens the fptrunc src to the next power of 2, meaning we can keep this simple if we only expect powers of 2. Otherwise, we might have to pad vectors so we can later concat them.
Oh - that's fine for now then. We should go through at some point and check non-power2 vector types.
Should I add a todo comment?
No that's fine, we need to go through all of them I think.
davemgreen left a comment:
Thanks, LGTM
…63398) This commit improves the lowering of vectors of fp16 when truncating and (previously) extending. Truncating has to be handled in a specific way to avoid double rounding.
…63398) This commit improves the lowering of vectors of fp16 when truncating and (previously) extending. Truncating has to be handled in a specific way to avoid double rounding.
…63398) This commit improves the lowering of vectors of fp16 when truncating and (previously) extending. Truncating has to be handled in a specific way to avoid double rounding.
This commit improves the lowering of vectors of fp16 when truncating and (previously) extending. Truncating has to be handled in a specific way to avoid double rounding.